{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip3 install malaya\n",
    "\n",
    "import malaya\n",
    "import re\n",
    "from malaya.texts._text_functions import split_into_sentences\n",
    "from malaya.texts import _regex\n",
    "\n",
    "splitter = split_into_sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip\n",
    "# !unzip uncased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from bert import tokenization\n",
    "\n",
    "# WordPiece vocabulary file inside the unzipped uncased_L-12_H-768_A-12 checkpoint directory.\n",
    "BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "    vocab_file=BERT_VOCAB,\n",
    "    do_lower_case=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "92579"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import glob\n",
    "\n",
    "# glob returns paths in arbitrary, filesystem-dependent order; sort them so that\n",
    "# stories[0], stories[1] and the order of the generated dataset are stable across runs.\n",
    "stories = sorted(glob.glob('cnn/stories/*.story'))\n",
    "len(stories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_story(doc):\n",
    "    \"\"\"Split one raw story file into tokenised body sentences and highlights.\n",
    "\n",
    "    Args:\n",
    "        doc: full text of a story; every highlight is introduced by '@highlight'.\n",
    "\n",
    "    Returns:\n",
    "        (stories, summaries): whitespace-split token lists, one per sentence\n",
    "        returned by `splitter` for the body and one per non-empty highlight.\n",
    "    \"\"\"\n",
    "    index = doc.find('@highlight')\n",
    "    if index == -1:\n",
    "        # No marker at all: the whole document is body text. Slicing with -1\n",
    "        # would cut the last character off and report it as a highlight.\n",
    "        story, highlights = doc, []\n",
    "    else:\n",
    "        story, highlights = doc[:index], doc[index:].split('@highlight')\n",
    "    # Strip first, then filter, so whitespace-only chunks are dropped too.\n",
    "    highlights = [h for h in map(str.strip, highlights) if h]\n",
    "    stories = [s.split() for s in splitter(story)]\n",
    "    summaries = [h.split() for h in highlights]\n",
    "    return stories, summaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentence / token limits. In the cells visible here only max_src_nsents,\n",
    "# min_src_ntokens_per_sent and max_src_ntokens_per_sent are read (by get_xy).\n",
    "min_src_nsents, max_src_nsents = 3, 20\n",
    "min_src_ntokens_per_sent, max_src_ntokens_per_sent = 5, 30\n",
    "min_tgt_ntokens, max_tgt_ntokens = 5, 500\n",
    "\n",
    "# Special tokens looked up in the BERT vocabulary below.\n",
    "sep_token, cls_token, pad_token = '[SEP]', '[CLS]', '[PAD]'\n",
    "# Presumably target-side markers mapped onto unused vocabulary slots;\n",
    "# nothing in the visible cells reads them - TODO confirm they are still needed.\n",
    "tgt_bos, tgt_eos, tgt_sent_split = '[unused0]', '[unused1]', '[unused2]'\n",
    "\n",
    "# Vocabulary ids of the special tokens.\n",
    "sep_vid = tokenizer.vocab[sep_token]\n",
    "cls_vid = tokenizer.vocab[cls_token]\n",
    "pad_vid = tokenizer.vocab[pad_token]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.\n",
    "with open(stories[0], encoding='utf-8') as fopen:\n",
    "    story = fopen.read()\n",
    "# split_story turns the raw text into (body sentences, highlights), both as token lists.\n",
    "story, highlights = split_story(story)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _get_ngrams(n, text):\n",
    "    ngram_set = set()\n",
    "    text_length = len(text)\n",
    "    max_index_ngram_start = text_length - n\n",
    "    for i in range(max_index_ngram_start + 1):\n",
    "        ngram_set.add(tuple(text[i:i + n]))\n",
    "    return ngram_set\n",
    "\n",
    "\n",
    "def _get_word_ngrams(n, sentences):\n",
    "    \"\"\"Return the set of word n-grams over all `sentences` joined into one word list.\n",
    "\n",
    "    Args:\n",
    "        n: n-gram size; must be positive.\n",
    "        sentences: non-empty list of token lists. The lists are concatenated\n",
    "            before n-grams are taken, so n-grams may span sentence boundaries.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `sentences` is empty or `n` is not positive.\n",
    "    \"\"\"\n",
    "    assert len(sentences) > 0\n",
    "    assert n > 0\n",
    "\n",
    "    # Flatten in one pass; sum(sentences, []) copies the growing list for every\n",
    "    # sentence, which is quadratic in the total number of words.\n",
    "    words = [word for sentence in sentences for word in sentence]\n",
    "    return _get_ngrams(n, words)\n",
    "\n",
    "def cal_rouge(evaluated_ngrams, reference_ngrams):\n",
    "    reference_count = len(reference_ngrams)\n",
    "    evaluated_count = len(evaluated_ngrams)\n",
    "\n",
    "    overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)\n",
    "    overlapping_count = len(overlapping_ngrams)\n",
    "\n",
    "    if evaluated_count == 0:\n",
    "        precision = 0.0\n",
    "    else:\n",
    "        precision = overlapping_count / evaluated_count\n",
    "\n",
    "    if reference_count == 0:\n",
    "        recall = 0.0\n",
    "    else:\n",
    "        recall = overlapping_count / reference_count\n",
    "\n",
    "    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))\n",
    "    return {\"f\": f1_score, \"p\": precision, \"r\": recall}\n",
    "\n",
    "\n",
    "def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):\n",
    "    \"\"\"Greedily pick sentences whose combined n-grams best match the abstract.\n",
    "\n",
    "    At every step the sentence that most increases ROUGE-1 F + ROUGE-2 F of the\n",
    "    selection is added; selection stops once no sentence improves the score.\n",
    "\n",
    "    Args:\n",
    "        doc_sent_list: list of token lists, one per document sentence.\n",
    "        abstract_sent_list: list of token lists, one per abstract sentence.\n",
    "        summary_size: maximum number of sentences to select.\n",
    "\n",
    "    Returns:\n",
    "        Sorted list of selected indices into `doc_sent_list` (at most\n",
    "        `summary_size` entries; sorted on every exit path).\n",
    "    \"\"\"\n",
    "    def _rouge_clean(s):\n",
    "        # Keep only ASCII letters, digits and spaces.\n",
    "        return re.sub(r'[^a-zA-Z0-9 ]', '', s)\n",
    "\n",
    "    max_rouge = 0.0\n",
    "    abstract = [word for sent in abstract_sent_list for word in sent]\n",
    "    abstract = _rouge_clean(' '.join(abstract)).split()\n",
    "    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]\n",
    "    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]\n",
    "    reference_1grams = _get_word_ngrams(1, [abstract])\n",
    "    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]\n",
    "    reference_2grams = _get_word_ngrams(2, [abstract])\n",
    "\n",
    "    selected = []\n",
    "    # n-grams of the sentences chosen so far; they do not change while the\n",
    "    # candidates of one round are scanned, so they are built incrementally.\n",
    "    selected_1grams, selected_2grams = set(), set()\n",
    "    for _ in range(summary_size):\n",
    "        cur_max_rouge = max_rouge\n",
    "        cur_id = -1\n",
    "        for i in range(len(sents)):\n",
    "            if i in selected:\n",
    "                continue\n",
    "            candidates_1 = selected_1grams | evaluated_1grams[i]\n",
    "            candidates_2 = selected_2grams | evaluated_2grams[i]\n",
    "            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']\n",
    "            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']\n",
    "            rouge_score = rouge_1 + rouge_2\n",
    "            if rouge_score > cur_max_rouge:\n",
    "                cur_max_rouge = rouge_score\n",
    "                cur_id = i\n",
    "        if cur_id == -1:\n",
    "            # No remaining sentence improves the score: stop early.\n",
    "            break\n",
    "        selected.append(cur_id)\n",
    "        selected_1grams |= evaluated_1grams[cur_id]\n",
    "        selected_2grams |= evaluated_2grams[cur_id]\n",
    "        max_rouge = cur_max_rouge\n",
    "\n",
    "    return sorted(selected)\n",
    "\n",
    "def get_xy(story, highlights, summary_size=3):\n",
    "    \"\"\"Turn one tokenised story into model inputs and sentence labels.\n",
    "\n",
    "    Args:\n",
    "        story: list of token lists, one per body sentence.\n",
    "        highlights: list of token lists, one per reference summary sentence.\n",
    "        summary_size: maximum number of sentences labelled 1 (forwarded to\n",
    "            greedy_selection).\n",
    "\n",
    "    Returns:\n",
    "        (src_subtoken_idxs, cls_ids, sent_labels, segments_ids):\n",
    "        vocabulary ids of the sequence [CLS] sent [SEP] [CLS] sent [SEP] ...,\n",
    "        the positions of every [CLS] id, one 0/1 label per kept sentence, and\n",
    "        segment ids alternating between 0 and 1 from one sentence to the next.\n",
    "        When no sentence passes the length filter a single [CLS] [SEP] pair is\n",
    "        still emitted, so cls_ids has one entry while sent_labels is empty.\n",
    "    \"\"\"\n",
    "    # Keep sentences longer than the minimum, clip each one, then cap the count.\n",
    "    src = [s[:max_src_ntokens_per_sent] for s in story if len(s) > min_src_ntokens_per_sent]\n",
    "    src = src[:max_src_nsents]\n",
    "\n",
    "    sent_labels = [0] * len(src)\n",
    "    for selected_idx in greedy_selection(src, highlights, summary_size):\n",
    "        sent_labels[selected_idx] = 1\n",
    "\n",
    "    src_subtokens = []\n",
    "    for i, sent in enumerate(src):\n",
    "        if i > 0:\n",
    "            # Close the previous sentence and open the next one.\n",
    "            src_subtokens.extend([sep_token, cls_token])\n",
    "        src_subtokens.extend(tokenizer.tokenize(' '.join(sent)))\n",
    "\n",
    "    src_subtokens = [cls_token] + src_subtokens + [sep_token]\n",
    "    src_subtoken_idxs = tokenizer.convert_tokens_to_ids(src_subtokens)\n",
    "\n",
    "    # Distance between consecutive [SEP] positions = length of each sentence span.\n",
    "    _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == sep_vid]\n",
    "    segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]\n",
    "    segments_ids = []\n",
    "    for i, span_len in enumerate(segs):\n",
    "        segments_ids += span_len * [i % 2]\n",
    "    cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == cls_vid]\n",
    "\n",
    "    return src_subtoken_idxs, cls_ids, sent_labels, segments_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode explicitly as UTF-8 so the result does not depend on the platform's default encoding.\n",
    "with open(stories[1], encoding='utf-8') as fopen:\n",
    "    story = fopen.read()\n",
    "story, highlights = split_story(story)\n",
    "text, cls_ids, sent_labels, segments_ids = get_xy(story, highlights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(20, 20, 661, 661)"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One label per [CLS] position, and one segment id per input token.\n",
    "tuple(len(seq) for seq in (sent_labels, cls_ids, text, segments_ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 92579/92579 [13:52<00:00, 111.15it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "texts, clss, labels, segments = [], [], [], []\n",
    "\n",
    "for path in tqdm(stories):\n",
    "    # Decode explicitly as UTF-8 so the result does not depend on the platform default.\n",
    "    with open(path, encoding='utf-8') as fopen:\n",
    "        story = fopen.read()\n",
    "    story, highlights = split_story(story)\n",
    "    text, cls_ids, sent_labels, segments_ids = get_xy(story, highlights)\n",
    "    # Skip stories whose number of [CLS] positions does not match the number of labels.\n",
    "    if len(cls_ids) != len(sent_labels):\n",
    "        continue\n",
    "    texts.append(text)\n",
    "    clss.append(cls_ids)\n",
    "    labels.append(sent_labels)\n",
    "    segments.append(segments_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Fix the seed so the train/test partition is identical on every run.\n",
    "(train_texts, test_texts,\n",
    " train_clss, test_clss,\n",
    " train_labels, test_labels,\n",
    " train_segments, test_segments) = train_test_split(\n",
    "    texts, clss, labels, segments, test_size=0.2, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Persist the train/test partitions of all four parallel lists in a single pickle.\n",
    "with open('dataset-bert.pkl', 'wb') as fopen:\n",
    "    pickle.dump(\n",
    "        {\n",
    "            'train_texts': train_texts,\n",
    "            'test_texts': test_texts,\n",
    "            'train_clss': train_clss,\n",
    "            'test_clss': test_clss,\n",
    "            'train_labels': train_labels,\n",
    "            'test_labels': test_labels,\n",
    "            'train_segments': train_segments,\n",
    "            'test_segments': test_segments,\n",
    "        },\n",
    "        fopen,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
